import traceback
from typing import Dict, List, Optional

from trinity.buffer.buffer import BufferWriter, get_buffer_reader, get_buffer_writer
from trinity.buffer.operators.experience_operator import ExperienceOperator
from trinity.buffer.storage.queue import is_database_url, is_json_file
from trinity.common.config import (
    AlgorithmConfig,
    BufferConfig,
    Config,
    ExperiencePipelineConfig,
    StorageConfig,
)
from trinity.common.constants import StorageType
from trinity.common.experience import Experience
from trinity.utils.log import get_logger
from trinity.utils.plugin_loader import load_plugins


def get_input_buffers(
    pipeline_config: ExperiencePipelineConfig, buffer_config: BufferConfig
) -> Dict:
    """Create one buffer reader per input declared in the experience pipeline config.

    Args:
        pipeline_config: Pipeline config whose ``inputs`` mapping associates an
            input name with the storage config to read from.
        buffer_config: Global buffer config forwarded to ``get_buffer_reader``.

    Returns:
        Dict: Mapping from each input name to the reader returned by
        ``get_buffer_reader`` for that input.
    """
    return {
        name: get_buffer_reader(storage_config, buffer_config)
        for name, storage_config in pipeline_config.inputs.items()
    }


class ExperiencePipeline:
    """
    A class to process experiences.

    Each batch passed to :meth:`process` is optionally archived as-is to an
    input store, then run through a chain of ``ExperienceOperator`` instances
    (user-configured operators followed by an optional algorithm-specific
    advantage operator), and finally written to the trainer's experience buffer.
    """

    def __init__(self, config: Config):
        """Build the pipeline from the global ``config``.

        Args:
            config: Global config. This reads ``explorer.name``,
                ``data_processor.experience_pipeline``, ``buffer`` and ``algorithm``.

        Raises:
            Exception: Re-raises any error from operator creation after logging
                its traceback. Also propagates ``ValueError`` from
                :meth:`_init_input_storage` for an invalid input-save configuration.
        """
        self.logger = get_logger(f"{config.explorer.name}_experience_pipeline", in_ray_actor=True)
        # Plugins must be registered before operators are looked up by name below.
        load_plugins()
        pipeline_config = config.data_processor.experience_pipeline
        buffer_config = config.buffer
        # Optional writer that archives the raw (pre-operator) experiences; None if disabled.
        self.input_store = self._init_input_storage(pipeline_config, buffer_config)  # type: ignore [arg-type]
        try:
            self.operators = ExperienceOperator.create_operators(pipeline_config.operators)
        except Exception:
            self.logger.error(f"Failed to create experience operators: {traceback.format_exc()}")
            raise
        self._set_algorithm_operators(config.algorithm)
        # Processed experiences are written to the trainer's experience buffer.
        self.output = get_buffer_writer(
            buffer_config.trainer_input.experience_buffer,  # type: ignore [arg-type]
            buffer_config,
        )

    def _init_input_storage(
        self,
        pipeline_config: ExperiencePipelineConfig,
        buffer_config: BufferConfig,
    ) -> Optional[BufferWriter]:
        """Create the writer used to archive raw input experiences, if enabled.

        Args:
            pipeline_config: Provides ``save_input`` (on/off switch) and
                ``input_save_path`` (destination).
            buffer_config: Global buffer config forwarded to ``get_buffer_writer``.

        Returns:
            A non-Ray-wrapped buffer writer targeting a JSON file (``StorageType.FILE``)
            or a database URL (``StorageType.SQL``). Returns ``None`` when
            ``save_input`` is falsy.

        Raises:
            ValueError: If ``save_input`` is set but ``input_save_path`` is missing
                or matches neither a JSON file path nor a database URL.
        """
        if not pipeline_config.save_input:
            return None

        save_path = pipeline_config.input_save_path
        if save_path is None:
            raise ValueError("input_save_path must be set when save_input is True.")

        if is_json_file(save_path):
            storage_type = StorageType.FILE
        elif is_database_url(save_path):
            storage_type = StorageType.SQL
        else:
            # Report the offending path; `save_input` is always True on this branch.
            raise ValueError(
                f"Unsupported input_save_path format: {save_path}. "
                "Only JSON file path or SQLite URL is supported."
            )

        return get_buffer_writer(
            StorageConfig(
                storage_type=storage_type,
                path=save_path,
                wrap_in_ray=False,
            ),
            buffer_config,
        )

    def _set_algorithm_operators(self, algorithm_config: AlgorithmConfig) -> None:
        """Append the algorithm's advantage operator when it runs in the pipeline.

        An advantage function is added only when the algorithm type does not
        compute advantages in the trainer and an ``advantage_fn`` is configured.

        Args:
            algorithm_config: Provides ``algorithm_type``, ``advantage_fn`` and
                ``advantage_fn_args``.

        Raises:
            AssertionError: If the advantage function is not registered, or if it
                reports that it must be computed in the trainer.
        """
        # Imported lazily, as in the original code (possibly to avoid an import cycle).
        from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE

        algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type)
        if not algorithm.compute_advantage_in_trainer and algorithm_config.advantage_fn:
            advantage_fn_cls = ADVANTAGE_FN.get(algorithm_config.advantage_fn)
            assert (
                advantage_fn_cls is not None
            ), f"AdvantageFn {algorithm_config.advantage_fn} not found."
            assert (
                not advantage_fn_cls.compute_in_trainer()
            ), f"AdvantageFn {algorithm_config.advantage_fn} can only be computed in the trainer, please check your implementation."
            # Tolerate an unset (None) argument mapping instead of failing on `**None`.
            advantage_fn_args = algorithm_config.advantage_fn_args or {}
            self.operators.append(advantage_fn_cls(**advantage_fn_args))

    async def prepare(self) -> None:
        """Acquire the output buffer before any experiences are written."""
        await self.output.acquire()

    async def process(self, exps: List[Experience]) -> Dict:
        """Process a batch of experiences.

        Args:
            exps (List[Experience]): List of experiences to process. These experiences are typically generated by an explorer in one step.

        Returns:
            Dict: Numeric metrics collected while processing, with keys prefixed
            by ``pipeline/`` and values converted to ``float``. Non-numeric metric
            values are dropped.
        """
        # Archive the raw batch before any operator can filter or modify it.
        if self.input_store is not None:
            await self.input_store.write_async(exps)
        metrics = {}

        # Each operator receives the output of the previous one; later metrics
        # with the same key overwrite earlier ones.
        for operator in self.operators:
            exps, metric = operator.process(exps)
            metrics.update(metric)

        metrics["experience_count"] = len(exps)

        # Write processed experiences to output buffer
        await self.output.write_async(exps)

        # Keep only numeric metrics and namespace them under 'pipeline/'.
        result_metrics = {}
        for key, value in metrics.items():
            if isinstance(value, (int, float)):
                result_metrics[f"pipeline/{key}"] = float(value)

        return result_metrics

    async def close(self) -> None:
        """Release the output buffer and close every operator, best-effort.

        Failures are logged rather than raised. A failing step does not prevent
        the remaining operators from being closed.
        """
        try:
            await self.output.release()
        except Exception as e:
            self.logger.error(f"Failed to release output buffer: {e}")
        for operator in self.operators:
            try:
                operator.close()
            except Exception as e:
                self.logger.error(f"Failed to close operator {type(operator).__name__}: {e}")
